package edu.uky.ai.ml;

import java.text.DecimalFormat;
import java.util.HashMap;
import java.util.Random;

import edu.uky.ai.ml.nn.Database;
import edu.uky.ai.ml.nn.Edge;
import edu.uky.ai.ml.nn.Example;
import edu.uky.ai.ml.nn.Layer;
import edu.uky.ai.ml.nn.NeuralNetwork;
import edu.uky.ai.ml.nn.Neuron;

/**
 * Creates and trains a neural network to recognize capital letters.
 * 
 * @author Your Name
 */
public class Main {
	
	/** Prints doubles with 3 digits before the decimal and 4 after */
	private static final DecimalFormat decimalFormat = new DecimalFormat("000.0000");
	
	/** Prints doubles with 3 digits before the decimal and 0 after */
	private static final DecimalFormat percentFormat = new DecimalFormat("000"); 

	/**
	 * Creates and trains a neural network, then creates a GUI that allows you
	 * to interact with the network to see its predictions.
	 * 
	 * @param args ignored
	 */
	public static void main(String[] args) {
		// Create the training database (in this case, 5x7 pixel images of
		// capital letters).
		Database database = new AlphabetDatabase();
		// Creates a neural network with random edge weights (fixed seed, so
		// runs are reproducible) and:
		// 35 input nodes, one for each of the 5x7 pixels in the image
		// 26 output nodes, one for each letter
		// 1 hidden layer with 10 nodes
		NeuralNetwork network = new NeuralNetwork(5*7, 26, 1, 10, new Random(0));
		// Print the initial error and accuracy of the network.
		System.out.println("Starting error:    " + decimalFormat.format(network.getError(database)));
		System.out.println("Starting accuracy: " + percentFormat.format(network.getAccuracy(database) * 100) + "%");
		// Train the network to recognize the training examples.
		learn(network, database);
		// Print the final error and accuracy of the network.
		System.out.println("Final error:       " + decimalFormat.format(network.getError(database)));
		System.out.println("Final accuracy:    " + percentFormat.format(network.getAccuracy(database) * 100) + "%");
		// Create a GUI to interact with the network.
		new AlphabetFrame(database, network);
	}
	
	/** The highest acceptable error for the network */
	public static final double ERROR_THRESHOLD = 0.05;
	
	/** The maximum number of training iterations to be run */
	public static final double MAX_ITERATIONS = 10000;
	
	/** How often (in iterations) a progress message is printed while training */
	private static final int PROGRESS_INTERVAL = 100;
	
	/**
	 * Given a neural network and a database of training examples, this method
	 * adjusts the weights in the network to recognize the examples.  It runs
	 * until either the total error of the network is below
	 * {@link #ERROR_THRESHOLD} or {@link #MAX_ITERATIONS} training iterations
	 * have been run.  One iteration back propagates once for every example
	 * in the database.
	 * <p>
	 * NOTE(review): assumes {@code Database} is {@code Iterable<Example>};
	 * confirm against edu.uky.ai.ml.nn.Database.
	 * 
	 * @param network the network to train
	 * @param database the examples on which to train the network
	 */
	private static void learn(NeuralNetwork network, Database database) {
		int iteration = 0;
		double error = network.getError(database);
		// The error is re-measured after every iteration, so a network that
		// reaches the threshold on the very last iteration still counts as
		// a success.
		while (error >= ERROR_THRESHOLD && iteration < MAX_ITERATIONS) {
			for (Example example : database) {
				backpropagate(network, example);
			}
			iteration++;
			error = network.getError(database);
			if (iteration % PROGRESS_INTERVAL == 0) {
				System.out.println("Iteration " + iteration + ": error = " + decimalFormat.format(error));
			}
		}
		if (error < ERROR_THRESHOLD) {
			System.out.println("Training succeeded after " + iteration + " iterations (error = " + decimalFormat.format(error) + ").");
		}
		else {
			System.out.println("Training failed: error did not fall below " + ERROR_THRESHOLD + " within " + iteration + " iterations (error = " + decimalFormat.format(error) + ").");
		}
	}
	
	/**
	 * Given a neural network and a single training example, this method
	 * adjusts the weights in the network so that it becomes more accurate
	 * at recognizing the example.
	 * <p>
	 * It feeds the example forward and then computes a delta value for every
	 * output and hidden neuron, working backwards from the output layer.
	 * Finally, it adds {@code parent value * child delta} to every edge's weight.
	 * <p>
	 * NOTE(review): the following members are not visible from this file and
	 * are assumed; confirm them against the edu.uky.ai.ml.nn classes:
	 * {@code Example#input}, {@code Example#output}, {@code NeuralNetwork#input},
	 * {@code NeuralNetwork#output}, {@code Layer} being
	 * {@code Iterable<Neuron>}, {@code Edge#parent} and {@code Edge#child}.
	 * 
	 * @param network the network to adjust
	 * @param example the example
	 */
	private static void backpropagate(NeuralNetwork network, Example example) {
		// Run the example through the network using setInput + getOutput, so
		// that every neuron's value reflects this example.  The returned array
		// is not needed, because each neuron's value is read directly below.
		network.setInput(example.input);
		network.getOutput();
		// Delta (error) value of every output and hidden neuron.
		HashMap<Neuron, Double> deltas = new HashMap<>();
		// Output neurons have no children, so their delta depends only on the
		// expected value for the corresponding output position.
		int index = 0;
		for (Neuron neuron : network.output) {
			deltas.put(neuron, delta(neuron, example.output[index++]));
		}
		// Walk backwards from the last hidden layer (the one just before the
		// output layer), stopping once the input layer is reached.  Each
		// layer's deltas depend on the deltas of the layer after it, which
		// have already been stored.
		for (Layer layer = network.output.previous; layer != network.input; layer = layer.previous) {
			for (Neuron neuron : layer) {
				deltas.put(neuron, delta(neuron, deltas));
			}
		}
		// Weights are only updated after all deltas are known, so every delta
		// above was computed from the pre-update weights.
		for (Edge edge : network.edges) {
			edge.weight += edge.parent.getValue() * deltas.get(edge.child);
		}
	}
	
	/**
	 * Returns the gradient of the sigmoid activation function at the neuron's
	 * current value n, which is {@code n * (1 - n)}.
	 * 
	 * @param neuron the neuron
	 * @return the gradient of its activation function
	 */
	private static double sigmoidGradient(Neuron neuron) {
		double value = neuron.getValue();
		return value * (1 - value);
	}
	
	/**
	 * Given an output neuron and its expected values, this method returns the
	 * delta value for that neuron: {@code (expected - actual) * n * (1 - n)}.
	 * 
	 * @param output the output neuron
	 * @param expected its expected value
	 * @return the delta value
	 */
	private static final double delta(Neuron output, double expected) {
		// Error is the difference between the expected and actual values.
		double error = expected - output.getValue();
		return error * sigmoidGradient(output);
	}
	
	/**
	 * Given a hidden neuron and the current table of delta values, this method
	 * calculates the neuron's delta value.  This is the sum, over all outgoing
	 * edges, of the edge's weight times the child neuron's delta, multiplied
	 * by the neuron's sigmoid gradient.  This method assumes that the delta
	 * values for all neurons in layers forward of this neuron's layer have
	 * already been calculated.
	 * <p>
	 * NOTE(review): assumes {@code Neuron#outgoing} holds the edges leaving
	 * this neuron and {@code Edge#child} their target; confirm against
	 * edu.uky.ai.ml.nn.
	 * 
	 * @param hidden the hidden neuron
	 * @param deltas a table of delta values for other neurons
	 * @return this neuron's delta value
	 */
	private static final double delta(Neuron hidden, HashMap<Neuron, Double> deltas) {
		double error = 0;
		for (Edge edge : hidden.outgoing) {
			error += edge.weight * deltas.get(edge.child);
		}
		return error * sigmoidGradient(hidden);
	}
}
